import torch
import torch.utils.data
from torchvision import datasets
import numpy as np
from typing import Optional, Callable, Tuple, Any

def tensor_rot_90(x):
    """Rotate the plane spanned by dims 1 and 2 of ``x`` by 90 degrees.

    Dim 2 is reversed and then dims 1 and 2 are swapped, so
    ``out[c, i, j] == x[c, j, W - 1 - i]`` where ``W = x.shape[2]``. This is
    the same result as ``torch.rot90(x, 1, dims=(1, 2))``.

    Presumably ``x`` is a single (channels, height, width) image -- the code
    only requires that dims 1 and 2 exist.
    """
    mirrored = torch.flip(x, dims=[2])
    return torch.transpose(mirrored, 1, 2)

def tensor_rot_180(x):
    """Rotate the plane spanned by dims 1 and 2 of ``x`` by 180 degrees.

    Both dims are reversed, so
    ``out[c, i, j] == x[c, H - 1 - i, W - 1 - j]`` with
    ``H, W = x.shape[1], x.shape[2]``. The output has the same shape as ``x``.
    """
    return torch.flip(x, dims=[1, 2])

def tensor_rot_270(x):
    """Rotate the plane spanned by dims 1 and 2 of ``x`` by 270 degrees.

    Dims 1 and 2 are swapped and the new last dim is then reversed, so
    ``out[c, i, j] == x[c, H - 1 - j, i]`` where ``H = x.shape[1]``. This is
    the same result as ``torch.rot90(x, 3, dims=(1, 2))``.
    """
    swapped = torch.transpose(x, 1, 2)
    return torch.flip(swapped, dims=[2])

def rotate_single_with_label(img, label):
    """Rotate one image according to a rotation class label.

    Args:
        img: Image handed to the ``tensor_rot_*`` helpers.
        label: Rotation class; 1, 2 and 3 select the 90, 180 and 270 degree
            helpers respectively.

    Returns:
        The rotated image, or ``img`` itself (not a copy) when ``label`` is 0
        or any other value that matches none of the three rotation classes.
    """
    # Plain ``==`` comparisons keep this working for Python ints as well as
    # numpy / 0-dim tensor labels.
    if label == 1:
        return tensor_rot_90(img)
    if label == 2:
        return tensor_rot_180(img)
    if label == 3:
        return tensor_rot_270(img)
    return img

def rotate_batch_with_labels(batch, labels):
    """Rotate every image of a batch by its own rotation label.

    Args:
        batch: Iterable of images; each one is passed to
            ``rotate_single_with_label``.
        labels: Iterable of rotation labels, paired with ``batch`` via
            ``zip`` (so the shorter of the two determines the output length).

    Returns:
        A tensor holding the rotated images concatenated along a new leading
        dimension.
    """
    rotated = [
        rotate_single_with_label(sample, rot_label).unsqueeze(0)
        for sample, rot_label in zip(batch, labels)
    ]
    return torch.cat(rotated)

def rotate_batch(batch, label='rand'):
    """Rotate a batch either randomly per image or by one fixed label.

    Args:
        batch: Batch of images; ``len(batch)`` labels are generated.
        label: ``'rand'`` to draw an independent label from ``{0, 1, 2, 3}``
            for every image, or an ``int`` applied to the whole batch.

    Returns:
        A ``(rotated_batch, labels)`` tuple where ``labels`` is a 1-D
        ``torch.long`` tensor with one entry per image.
    """
    count = len(batch)
    if label != 'rand':
        assert isinstance(label, int)
        labels = torch.full((count,), label, dtype=torch.long)
    else:
        labels = torch.randint(4, (count,), dtype=torch.long)
    rotated = rotate_batch_with_labels(batch, labels)
    return rotated, labels


class RotateImageFolder(datasets.ImageFolder):
    """``ImageFolder`` that also yields a randomly rotated copy of each image.

    Every item is a 4-tuple ``(img, img_ssh, target, target_ssh)``:
    the (transformed) image, the same image after
    ``rotate_single_with_label``, the class index from the folder structure,
    and the rotation label that was drawn.
    """

    def __init__(self, traindir, train_transform, ):
        super().__init__(traindir, train_transform)

    def __len__(self):
        return super().__len__()

    def __getitem__(self, index):
        """Load item ``index`` and pair it with a rotated version.

        NOTE(review): the rotation helpers call ``.flip`` and
        ``.transpose(1, 2)``, so the transform presumably has to return a
        3-D tensor -- confirm callers never pass ``None`` as the transform.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        # Rotation label drawn from {0, 1, 2, 3} with numpy's global RNG.
        target_ssh = np.random.randint(0, 4, 1)[0]
        img_ssh = rotate_single_with_label(img, target_ssh)

        return (img, img_ssh, target, target_ssh)


class ExtendedRotatedImageFolder(datasets.ImageFolder):
    """``ImageFolder`` whose items are stacks of views of one image.

    For each item a single image file is loaded, turned into ``batch_size``
    views by the transform, stacked along a new leading dimension, and then
    passed through ``rotate_batch`` so that every view also gets a randomly
    rotated counterpart.

    Args:
        root: Root directory handed to ``ImageFolder``.
        batch_size: Number of views stacked per item.
        steps_per_example: Number of consecutive indices that map to the
            same underlying image (``index // steps_per_example``).
        minimizer: Optional indexable object; when given, the example
            position is looked up in it to obtain the index into
            ``self.samples`` and its length replaces the dataset size in
            ``__len__``.
        transform: Callable applied to the loaded image. It is required for
            item access because its outputs are stacked with ``torch.stack``.
        single_crop: If true, the transform is applied once and the same
            result is repeated ``batch_size`` times; otherwise the transform
            is applied ``batch_size`` times.
        start_index: Offset added to the example position before lookup.
    """

    def __init__(self, root: str, batch_size: int = 1, steps_per_example: int = 1, minimizer = None, transform: Optional[Callable] = None, single_crop: bool = False, start_index: int = 0):
        super().__init__(root=root, transform=transform)
        self.batch_size = batch_size
        self.minimizer = minimizer
        self.steps_per_example = steps_per_example
        self.single_crop = single_crop
        self.start_index = start_index

    def __len__(self):
        """Return ``steps_per_example * batch_size * number_of_examples``.

        The number of examples is ``len(self.minimizer)`` when a minimizer is
        set, otherwise the size of the underlying ``ImageFolder``.

        NOTE(review): ``__getitem__`` only divides the index by
        ``steps_per_example``, so with ``batch_size > 1`` the upper part of
        ``range(len(self))`` maps past the last example unless the caller
        stops early -- confirm this is what callers expect.
        """
        if self.minimizer is None:
            num_examples = super().__len__()
        else:
            num_examples = len(self.minimizer)
        return self.steps_per_example * self.batch_size * num_examples

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (samples, samples_rot, target, target_rot) where samples
            is the stack of transformed views, samples_rot is the same stack
            after rotate_batch, target is class_index of the target class and
            target_rot holds the rotation label drawn for each view.

        Raises:
            ValueError: If no transform was configured.
        """
        # Without a transform there is nothing to stack; previously this path
        # fell through both branches and died with an UnboundLocalError.
        if self.transform is None:
            raise ValueError(
                "ExtendedRotatedImageFolder requires a transform that returns "
                "tensors; got transform=None"
            )

        real_index = (index // self.steps_per_example) + self.start_index
        if self.minimizer is not None:
            real_index = self.minimizer[real_index]
        path, target = self.samples[real_index]
        sample = self.loader(path)

        if self.single_crop:
            # One transform call, repeated so all views are identical.
            view = self.transform(sample)
            views = [view] * self.batch_size
        else:
            # A fresh transform call per view.
            views = [self.transform(sample) for _ in range(self.batch_size)]
        samples = torch.stack(views, dim=0)

        if self.target_transform is not None:
            target = self.target_transform(target)
        samples_rot, target_rot = rotate_batch(samples)
        return samples, samples_rot, target, target_rot